{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# [dev] PyMC4 Design overview. Generators, Coroutines and all the things"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a brief introduction of Coroutines you can first read [PEP0342](https://www.python.org/dev/peps/pep-0342/), but we will cover most of the basics here with a practical example. Here we will make a draft of a PPL on top of `tfp`. The challenge is to use dynamic graph building and designing a flexible framework."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A simple case"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow_probability as tfp\n",
    "import tensorflow as tf\n",
    "from tensorflow_probability import distributions as tfd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's first look at \"HOW-TO PPL\". Essentially, we want to implement a probabilisitc program (think of a Bayesian model that you usually represent in a directed acyclic graph) that does 2 things: \n",
    "\n",
    "1. forward sampling to get prior (predictive) samples; \n",
    "2. reverse evaluation on some inputs to compute the log-probability (if we conditioned on the observed, we are evaluating the unnormalized posterior distribution of the unknown random variables). \n",
    "\n",
    "Specifically for computing the log_prob, we need to keep track of the dependency. For example, if we have something simple like `x ~ Normal(0, 1), y ~ Normal(x, 1)`, in the reverse mode (goal 2 from above) we need to swap `x` with the input `x` (either from user or programmatically) to essentially do `log_prob = Normal(0, 1).log_prob(x_input) + Normal(x_input, 1).log_prob(y_input)`.\n",
    "\n",
    "There are a few approaches to this problem. For example, in PyMC3 with a static graph backend (theano), we are able to write things in a declarative way and use the static computational graph to keep track of the dependences. This means we are able to do `log_prob = x.log_prob(x) + y.log_prob(y)`, as the value representation of the random variables `x` and `y` are \"swap-in\" with some input at runtime. With a dynamic graph we run into problems as we write the model in a forward mode -- we essentially lose track of the dependence in the reverse calling mode. We need to either keep track of the graph representation ourselves, or write a function that could be reevaluated to make sure we have the right dependencies. "
   ]
  },
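  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of goal 2 for the two-variable example above, the joint log-probability can be written out by hand in plain Python. The helper names `normal_log_prob` and `joint_log_prob` are made up for illustration and are not part of `tfp`; the point is only that `x_input` appears twice, once as the value scored under `Normal(0, 1)` and once as the location of the distribution of `y`.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def normal_log_prob(value, loc, scale):\n",
    "    # Log-density of Normal(loc, scale) evaluated at value (hypothetical helper).\n",
    "    z = (value - loc) / scale\n",
    "    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)\n",
    "\n",
    "\n",
    "def joint_log_prob(x_input, y_input):\n",
    "    # x ~ Normal(0, 1) and y ~ Normal(x, 1), with x replaced by the supplied x_input.\n",
    "    return normal_log_prob(x_input, 0.0, 1.0) + normal_log_prob(y_input, x_input, 1.0)\n",
    "\n",
    "\n",
    "print(joint_log_prob(0.5, 1.0))\n",
    "```\n",
    "\n",
    "Writing this by hand does not scale beyond toy models, which is why the rest of this notebook looks for a way to derive it from the forward model."
   ]
  },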
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This ideally should look like this"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(x):\n",
    "    scale = tfd.HalfCauchy(0, 1)\n",
    "    coefs = tfd.Normal(tf.zeros(x.shape[1]), 1, )\n",
    "    predictions = tfd.Normal(tf.linalg.matvec(x, coefs), scale)\n",
    "    return predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "But this function will not work (you can try it yourself), because there is no random variable concept in `tfp`, meaning that you cannot do `RV_a.log_prob(RV_a)` (yes, just think of Random Variable in a PPL as a tensor/array-like object that you can do computation and a log_prob method that we can evaluate it on itself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model(tf.random.normal((100, 10)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generating a log_prob from this model also wont work, because the computation is different than the function we have written down above.\n",
    "\n",
    "What we want here is to track function evaluation at runtime, depending on the context (goal 1 or goal 2 from above).\n",
    "\n",
    "The very first way to cope with is was writing a wrapper over a distribution object. This wrapper was intended to catch a call to the distribution and use context to figure out what to do: for goal 1, we draw a sample from a random variable and plug the concrete value into the downstream dependencies; for goal 2 we got the concrete value and evaluate it with the respective random variable, and also plug the concrete value into the downstream dependencies. Here we use Coroutines from Python to have dynamic control flow that could achieve this kind of deferred execution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(x):\n",
    "    scale = yield tfd.HalfCauchy(0, 1)\n",
    "    coefs = yield tfd.Normal(tf.zeros(x.shape[1]), 1, )\n",
    "    predictions = yield tfd.Normal(tf.linalg.matvec(x, coefs), scale)\n",
    "    return predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we evaluate the model as expected but `yield` allows us to defer the program execution. But before evaluating this function, let's figure out what does yield do."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generator(message):\n",
    "    print(\"I am a generator and I yield\", message)\n",
    "    response = yield message\n",
    "    print(\"I am a generator and I got\", response)\n",
    "    return \"goodbye\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = generator(\"(generators are cool)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "I am a generator and I yield (generators are cool)\n"
     ]
    }
   ],
   "source": [
    "mes = g.send(None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(generators are cool)\n"
     ]
    }
   ],
   "source": [
    "print(mes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "I am a generator and I got (Indeed, generators are cool)\n"
     ]
    },
    {
     "ename": "StopIteration",
     "evalue": "goodbye",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mStopIteration\u001b[0m                             Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-18-fd1ede79425c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"(Indeed, generators are cool)\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mStopIteration\u001b[0m: goodbye"
     ]
    }
   ],
   "source": [
    "g.send(\"(Indeed, generators are cool)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What has happened here:\n",
    "\n",
    "* we had a simple generator and were able to communicate with it via `send`\n",
    "* after `send` is called (first time requires it to have `None` argument) generator goes to the next `yield` expression and yields what it is asked to yield.\n",
    "* as a return value from `send` we have this exact message from `yield message`\n",
    "* we set the lhs of `response = yield message` with next `send` and no earlier\n",
    "* after generator has no `yield` statements left and finally reaches `return`, it raises `StopIteration` with return value as a first argument\n",
    "\n",
    "Now we are ready to evaluate our model by hand"
   ]
  },
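  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same protocol can be sketched with a slightly longer generator. The `echo` generator below is purely illustrative and uses nothing beyond standard Python; it also shows that the return value is available as the `value` attribute of the `StopIteration` exception.\n",
    "\n",
    "```python\n",
    "def echo():\n",
    "    # Each yield hands a value out; the next send() supplies the value of that yield expression.\n",
    "    received = yield 'first'\n",
    "    received = yield 'got ' + str(received)\n",
    "    return 'done with ' + str(received)\n",
    "\n",
    "\n",
    "g = echo()\n",
    "out1 = g.send(None)  # the first send must pass None; runs up to the first yield\n",
    "out2 = g.send('a')   # resumes with received = 'a'; runs up to the second yield\n",
    "try:\n",
    "    g.send('b')      # resumes with received = 'b'; reaches return\n",
    "except StopIteration as e:\n",
    "    result = e.value  # the generator's return value\n",
    "\n",
    "print(out1, out2, result)\n",
    "```"
   ]
  },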
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "state = dict(dists=dict(), samples=dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "state[\"input\"] = tf.random.normal((3, 10))\n",
    "m = model(state[\"input\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "scale_dist = next(m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tfp.distributions.HalfCauchy(\"HalfCauchy\", batch_shape=[], event_shape=[], dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "print(scale_dist)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "which means, we are here\n",
    "\n",
    "```python\n",
    "def model(x):\n",
    "    scale = yield tfd.HalfCauchy(0, 1) # <--- HERE\n",
    "    coefs = yield tfd.Normal(tf.zeros(x.shape[1]), 1, )\n",
    "    ...\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What can we do with this distribution? We can choose forward sampling (in this case we sample from the state-less distribution `HalfCauchy(0, 1)`). But we need it to be used by user seamlessly later on regardless of the context (goal 1 or 2 above). On the model side, we need to store intermediate values and its associated distributions (hey! that's a random variable!)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert scale_dist.name not in state[\"dists\"]\n",
    "state[\"samples\"][scale_dist.name] = scale_dist.sample()\n",
    "state[\"dists\"][scale_dist.name] = scale_dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "coefs_dist = m.send(state[\"samples\"][scale_dist.name])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tfp.distributions.Normal(\"Normal\", batch_shape=[10], event_shape=[], dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "print(coefs_dist)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```python\n",
    "def model(x):\n",
    "    scale = yield tfd.HalfCauchy(0, 1)\n",
    "    coefs = yield tfd.Normal(tf.zeros(x.shape[1]), 1, ) # <--- WE ARE NOW HERE\n",
    "    predictions = yield tfd.Normal(tf.linalg.matvec(x, coefs), scale)\n",
    "    return predictions\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We do the same thing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert coefs_dist.name not in state[\"dists\"]\n",
    "state[\"samples\"][coefs_dist.name] = coefs_dist.sample()\n",
    "state[\"dists\"][coefs_dist.name] = coefs_dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds_dist = m.send(state[\"samples\"][coefs_dist.name])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tfp.distributions.Normal(\"Normal\", batch_shape=[3], event_shape=[], dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "print(preds_dist)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```python\n",
    "def model(x):\n",
    "    scale = yield tfd.HalfCauchy(0, 1)\n",
    "    coefs = yield tfd.Normal(tf.zeros(x.shape[1]), 1, )\n",
    "    predictions = yield tfd.Normal(tf.linalg.matvec(x, coefs), scale) # <--- NOW HERE\n",
    "    return predictions\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are now facing predictive distribution. Here we have several options:\n",
    "* sample from it: we get prior predictive\n",
    "* set custom values instead of sample, essentially conditioning on data. We might be interested in this to compute unnormalized posterior\n",
    "* replace it with another distribution, arbitrary magic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "ename": "AssertionError",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-29-67c730887c51>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mpreds_dist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"dists\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"samples\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpreds_dist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpreds_dist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_shape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"dists\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpreds_dist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpreds_dist\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAssertionError\u001b[0m: "
     ]
    }
   ],
   "source": [
    "assert preds_dist.name not in state[\"dists\"]\n",
    "state[\"samples\"][preds_dist.name] = tf.zeros(preds_dist.batch_shape)\n",
    "state[\"dists\"][preds_dist.name] = preds_dist"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Gotcha, we found duplicated names in our toy graphical model. We can easily tell our user to rewrite the model to get rid of duplicate names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "We found duplicate names in your cool model: Normal, so far we have other variables in the model, {'Normal', 'HalfCauchy'}",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-30-0eaccbf92d8b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      2\u001b[0m     \u001b[0;34m\"We found duplicate names in your cool model: {}, \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m     \"so far we have other variables in the model, {}\".format(\n\u001b[0;32m----> 4\u001b[0;31m         \u001b[0mpreds_dist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"dists\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m     )\n\u001b[1;32m      6\u001b[0m ))\n",
      "\u001b[0;32m<ipython-input-2-cd13f203443c>\u001b[0m in \u001b[0;36mmodel\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m      2\u001b[0m     \u001b[0mscale\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0mtfd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mHalfCauchy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m     \u001b[0mcoefs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0mtfd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNormal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m     \u001b[0mpredictions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0mtfd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNormal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinalg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcoefs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscale\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mpredictions\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: We found duplicate names in your cool model: Normal, so far we have other variables in the model, {'Normal', 'HalfCauchy'}"
     ]
    }
   ],
   "source": [
    "m.throw(RuntimeError(\n",
    "    \"We found duplicate names in your cool model: {}, \"\n",
    "    \"so far we have other variables in the model, {}\".format(\n",
    "        preds_dist.name, set(state[\"dists\"].keys()), \n",
    "    )\n",
    "))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The good thing is that we *communicate* with user, and can give meaningful exceptions with few pain.\n",
    "\n",
    "The correct model should look like this:\n",
    "\n",
    "```python\n",
    "def model(x):\n",
    "    scale = yield tfd.HalfCauchy(0, 1)\n",
    "    coefs = yield tfd.Normal(tf.zeros(x.shape[1]), 1, )\n",
    "    predictions = yield tfd.Normal(tf.linalg.matvec(x, coefs), scale, name=\"Normal_1\") # <--- HERE we asked out user to change the name\n",
    "    return predictions\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's set all the names according to the new model and interact with user again using the same model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m.gi_running"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our generator is now at the end of its execution - we can't interact with it any more. Let's create a new one and reevaluate with same sampled values (A hint how to get the desired `logp` function)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(x):\n",
    "    scale = yield tfd.HalfCauchy(0, 1)\n",
    "    coefs = yield tfd.Normal(tf.zeros(x.shape[1]), 1, )\n",
    "    predictions = yield tfd.Normal(tf.linalg.matvec(x, coefs), scale, name=\"Normal_1\") # <--- HERE we asked out user to change the name\n",
    "    return predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tfp.distributions.HalfCauchy(\"HalfCauchy\", batch_shape=[], event_shape=[], dtype=float32)\n",
      "tfp.distributions.Normal(\"Normal\", batch_shape=[10], event_shape=[], dtype=float32)\n",
      "tfp.distributions.Normal(\"Normal_1\", batch_shape=[3], event_shape=[], dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "m = model(state[\"input\"])\n",
    "print(m.send(None))\n",
    "print(m.send(state[\"samples\"][\"HalfCauchy\"]))\n",
    "print(m.send(state[\"samples\"][\"Normal\"]))\n",
    "try:\n",
    "    m.send(tf.zeros(state[\"input\"].shape[0]))\n",
    "except StopIteration as e:\n",
    "    stop_iteration = e\n",
    "else:\n",
    "    raise RuntimeError(\"No exception met\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor([0. 0. 0.], shape=(3,), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "print(stop_iteration)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instead of returning some value in the last `send`, generator raises `StopIteration` because it is exhausted and reached the `return` statement (no more `yield` met). As explained (and checked here) in [PEP0342](https://www.python.org/dev/peps/pep-0342/), we have a return value inside"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Automate process above"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We all are lazy humans and cant stand doing repetitive things. In our model evaluation we followed pretty simple rules:\n",
    "* asserting name is not used\n",
    "* checking if we should sample or place a specific value instead\n",
    "* recording distributions and samples\n",
    "\n",
    "Next step is to make a function that does all this instead of us. In this tutorial let's keep it simple:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "def interact(gen, state):\n",
    "    control_flow = gen()\n",
    "    return_value = None\n",
    "    while True:\n",
    "        try:\n",
    "            dist = control_flow.send(return_value)\n",
    "            if dist.name in state[\"dists\"]:\n",
    "                control_flow.throw(RuntimeError(\n",
    "                    \"We found duplicate names in your cool model: {}, \"\n",
    "                    \"so far we have other variables in the model, {}\".format(\n",
    "                        dist.name, set(state[\"dists\"].keys()), \n",
    "                    )\n",
    "                ))\n",
    "            if dist.name in state[\"samples\"]:\n",
    "                return_value = state[\"samples\"][dist.name]\n",
    "            else:\n",
    "                return_value = dist.sample()\n",
    "                state[\"samples\"][dist.name] = return_value\n",
    "            state[\"dists\"][dist.name] = dist\n",
    "        except StopIteration as e:\n",
    "            if e.args:\n",
    "                return_value = e.args[0]\n",
    "            else:\n",
    "                return_value = None\n",
    "            break\n",
    "    return return_value, state"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This implementation assumes no arg generator, we make things just simple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds, state = interact(lambda: model(tf.random.normal((3, 10))), state=dict(dists=dict(), samples=dict()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'dists': {'HalfCauchy': <tfp.distributions.HalfCauchy 'HalfCauchy' batch_shape=[] event_shape=[] dtype=float32>,\n",
       "  'Normal': <tfp.distributions.Normal 'Normal' batch_shape=[10] event_shape=[] dtype=float32>,\n",
       "  'Normal_1': <tfp.distributions.Normal 'Normal_1' batch_shape=[3] event_shape=[] dtype=float32>},\n",
       " 'samples': {'HalfCauchy': <tf.Tensor: shape=(), dtype=float32, numpy=35.769157>,\n",
       "  'Normal': <tf.Tensor: shape=(10,), dtype=float32, numpy=\n",
       "  array([-0.41907588,  0.8893132 , -0.00785866,  2.1496706 , -0.21485494,\n",
       "         -0.79587907, -0.9849972 ,  1.261421  ,  0.40662226, -0.38024214],\n",
       "        dtype=float32)>,\n",
       "  'Normal_1': <tf.Tensor: shape=(3,), dtype=float32, numpy=array([  0.3047135, -33.16318  ,  11.798283 ], dtype=float32)>}}"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(3,), dtype=float32, numpy=array([  0.3047135, -33.16318  ,  11.798283 ], dtype=float32)>"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We get all the things as expected. To calculate `logp` you just iterate over distributions and match them with the correspondig values. But let's dive deeper"
   ]
  },
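  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of that `logp` computation, assuming only that each recorded distribution exposes `log_prob` and that the state has the `dists`/`samples` layout used above. `ToyNormal`, `total_log_prob` and `toy_state` are hypothetical stand-ins so that the example runs without `tfp`; with real `tfp` distributions each term is a tensor and would probably need a `tf.reduce_sum` before being added up.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "class ToyNormal:\n",
    "    # Hypothetical stand-in for a distribution object: it only has `name` and `log_prob`.\n",
    "    def __init__(self, loc, scale, name):\n",
    "        self.loc, self.scale, self.name = loc, scale, name\n",
    "\n",
    "    def log_prob(self, value):\n",
    "        z = (value - self.loc) / self.scale\n",
    "        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2.0 * math.pi)\n",
    "\n",
    "\n",
    "def total_log_prob(state):\n",
    "    # Match every recorded distribution with the value recorded under the same name.\n",
    "    return sum(\n",
    "        dist.log_prob(state['samples'][name])\n",
    "        for name, dist in state['dists'].items()\n",
    "    )\n",
    "\n",
    "\n",
    "toy_state = dict(\n",
    "    dists={'x': ToyNormal(0.0, 1.0, 'x'), 'y': ToyNormal(0.5, 1.0, 'y')},\n",
    "    samples={'x': 0.5, 'y': 1.0},\n",
    ")\n",
    "print(total_log_prob(toy_state))\n",
    "```\n",
    "\n",
    "Note that the distribution of `y` was built with the value of `x` already plugged in, which is exactly what re-running the generator with fixed samples gives us."
   ]
  },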
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## One level deeper"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Recall the motivating example from [PR#125](https://github.com/pymc-devs/pymc4/pull/125)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Horseshoe(mu=0, tau=1., s=1., name=None):\n",
    "    with tf.name_scope(name):\n",
    "        scale = yield tfd.HalfCauchy(0, s, name=\"scale\")\n",
    "        noise = yield tfd.Normal(0, tau, name=\"noise\")\n",
    "        return scale * noise + mu\n",
    "\n",
    "\n",
    "def linreg(x):\n",
    "    scale = yield tfd.HalfCauchy(0, 1, name=\"scale\")\n",
    "    coefs = yield Horseshoe(tf.zeros(x.shape[1]), name=\"coefs\")\n",
    "    predictions = yield tfd.Normal(tf.linalg.matvec(x, coefs), scale, name=\"predictions\")\n",
    "    return predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'generator' object has no attribute 'name'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-41-417ed8050ceb>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mpreds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minteract\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mlinreg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnormal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdists\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msamples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-36-4ce07695efeb>\u001b[0m in \u001b[0;36minteract\u001b[0;34m(gen, state)\u001b[0m\n\u001b[1;32m      5\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m             \u001b[0mdist\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcontrol_flow\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreturn_value\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m             \u001b[0;32mif\u001b[0m \u001b[0mdist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"dists\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      8\u001b[0m                 control_flow.throw(RuntimeError(\n\u001b[1;32m      9\u001b[0m                     \u001b[0;34m\"We found duplicate names in your cool model: {}, \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'generator' object has no attribute 'name'"
     ]
    }
   ],
   "source": [
    "preds, state = interact(lambda: linreg(tf.random.normal((3, 10))), state=dict(dists=dict(), samples=dict()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Oooups, we have a type error. What we want is a nested model, but nesting models is something different from a plain generator. As we have our model being a generator itself, the return value of `Horseshoe(tf.zeros(x.shape[1]), name=\"coefs\")` is a generator. Of course this generator has no name attribute. Okay, we can ask user to use `yield from` construction to generate from the generator"
   ]
  },
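  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before applying this to the model, here is a minimal pure-Python sketch of what `yield from` does. The names `inner`, `outer` and `drive` are illustrative only: `yield from` forwards everything the inner generator yields to whoever drives the outer one, routes `send` values back to the inner generator, and evaluates to the inner generator's return value.\n",
    "\n",
    "```python\n",
    "def inner():\n",
    "    a = yield 'inner_a'\n",
    "    b = yield 'inner_b'\n",
    "    return a + b\n",
    "\n",
    "\n",
    "def outer():\n",
    "    first = yield 'outer_first'\n",
    "    # Delegate to inner(); the expression evaluates to inner()'s return value.\n",
    "    nested = yield from inner()\n",
    "    return first, nested\n",
    "\n",
    "\n",
    "def drive(gen, values):\n",
    "    # Send the given values one by one and collect everything the generator yields.\n",
    "    yielded = [gen.send(None)]\n",
    "    for v in values:\n",
    "        try:\n",
    "            yielded.append(gen.send(v))\n",
    "        except StopIteration as e:\n",
    "            return yielded, e.value\n",
    "    return yielded, None\n",
    "\n",
    "\n",
    "yielded, result = drive(outer(), [1, 2, 3])\n",
    "print(yielded, result)\n",
    "```\n",
    "\n",
    "From the driver's point of view the nested generator is invisible: it only ever sees a flat stream of yielded objects."
   ]
  },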
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "def linreg_ugly(x):\n",
    "    scale = yield tfd.HalfCauchy(0, 1, name=\"scale\")\n",
    "    coefs = yield from Horseshoe(tf.zeros(x.shape[1]), name=\"coefs\")\n",
    "    predictions = yield tfd.Normal(tf.linalg.matvec(x, coefs), scale, name=\"predictions\")\n",
    "    return predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds, state = interact(lambda: linreg_ugly(tf.random.normal((3, 10))), state=dict(dists=dict(), samples=dict()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Okay, we passed this thing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'scale': <tfp.distributions.HalfCauchy 'scale' batch_shape=[] event_shape=[] dtype=float32>,\n",
       " 'coefs_scale': <tfp.distributions.HalfCauchy 'coefs_scale' batch_shape=[] event_shape=[] dtype=float32>,\n",
       " 'coefs_noise': <tfp.distributions.Normal 'coefs_noise' batch_shape=[] event_shape=[] dtype=float32>,\n",
       " 'predictions': <tfp.distributions.Normal 'predictions' batch_shape=[3] event_shape=[] dtype=float32>}"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "state[\"dists\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We got nesting models working, but it requires `yield from`. This is UGLY and potentially confusing for user. Fortunately, we can rewrite out `interact` function to accept nested models in a few lines, and let the Python do the task for us."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "import types\n",
    "def interact_nested(gen, state):\n",
    "    # for now we should check input type\n",
    "    if not isinstance(gen, types.GeneratorType):\n",
    "        control_flow = gen()\n",
    "    else:\n",
    "        control_flow = gen\n",
    "\n",
    "    return_value = None\n",
    "    while True:\n",
    "        try:\n",
    "            dist = control_flow.send(return_value)\n",
    "            # this makes nested models possible\n",
    "            if isinstance(dist, types.GeneratorType):\n",
    "                return_value, state = interact_nested(dist, state)\n",
    "                # ^ in a few lines of code, go recursive\n",
    "            else:\n",
    "                if dist.name in state[\"dists\"]:\n",
    "                    control_flow.throw(RuntimeError(\n",
    "                        \"We found duplicate names in your cool model: {}, \"\n",
    "                        \"so far we have other variables in the model, {}\".format(\n",
    "                            dist.name, set(state[\"dists\"].keys()), \n",
    "                        )\n",
    "                    ))\n",
    "                if dist.name in state[\"samples\"]:\n",
    "                    return_value = state[\"samples\"][dist.name]\n",
    "                else:\n",
    "                    return_value = dist.sample()\n",
    "                    state[\"samples\"][dist.name] = return_value\n",
    "                state[\"dists\"][dist.name] = dist\n",
    "        except StopIteration as e:\n",
    "            if e.args:\n",
    "                return_value = e.args[0]\n",
    "            else:\n",
    "                return_value = None\n",
    "            break\n",
    "    return return_value, state"
   ]
  },
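  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The recursive step in `interact_nested` can be isolated in a toy sketch, assuming plain strings in place of distributions and a lookup table in place of sampling. `child`, `parent` and `drive_nested` are hypothetical names used only for illustration.\n",
    "\n",
    "```python\n",
    "import types\n",
    "\n",
    "\n",
    "def child():\n",
    "    x = yield 'child_x'\n",
    "    return x * 2\n",
    "\n",
    "\n",
    "def parent():\n",
    "    # A plain `yield` of a generator object, no `yield from` needed.\n",
    "    doubled = yield child()\n",
    "    y = yield 'parent_y'\n",
    "    return doubled + y\n",
    "\n",
    "\n",
    "def drive_nested(gen, replies, seen):\n",
    "    # Hypothetical toy driver; strings stand in for distributions.\n",
    "    reply = None\n",
    "    while True:\n",
    "        try:\n",
    "            request = gen.send(reply)\n",
    "        except StopIteration as stop:\n",
    "            return stop.value\n",
    "        if isinstance(request, types.GeneratorType):\n",
    "            # Recurse: the nested generator's return value becomes the reply.\n",
    "            reply = drive_nested(request, replies, seen)\n",
    "        else:\n",
    "            seen.append(request)\n",
    "            reply = replies[request]\n",
    "\n",
    "\n",
    "seen = []\n",
    "result = drive_nested(parent(), {'child_x': 10, 'parent_y': 1}, seen)\n",
    "print(seen)    # ['child_x', 'parent_y']\n",
    "print(result)  # 21\n",
    "```\n",
    "\n",
    "The parent only ever receives the child's return value, exactly as it would with `yield from`, but the model author writes an ordinary `yield`."
   ]
  },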
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Remember that this call failed earlier:\n",
    "```python\n",
    "preds, state = interact(lambda: linreg(tf.random.normal((3, 10))), state=dict(dists=dict(), samples=dict()))\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
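    "Additionally, we can specify an observed variable by pre-populating `state[\"samples\"]`. The key has to match `dist.name` exactly: in the run below the key is `\"predictions/\"` while the distribution is registered as `\"predictions\"`, so the supplied zeros are left unused and `predictions` is sampled afresh."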
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds, state = interact_nested(lambda: linreg(tf.random.normal((3, 10))), state=dict(dists=dict(), samples={\"predictions/\":tf.zeros(3)}))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'scale': <tfp.distributions.HalfCauchy 'scale' batch_shape=[] event_shape=[] dtype=float32>,\n",
       " 'coefs_scale': <tfp.distributions.HalfCauchy 'coefs_scale' batch_shape=[] event_shape=[] dtype=float32>,\n",
       " 'coefs_noise': <tfp.distributions.Normal 'coefs_noise' batch_shape=[] event_shape=[] dtype=float32>,\n",
       " 'predictions': <tfp.distributions.Normal 'predictions' batch_shape=[3] event_shape=[] dtype=float32>}"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "state[\"dists\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'predictions/': <tf.Tensor: shape=(3,), dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>,\n",
       " 'scale': <tf.Tensor: shape=(), dtype=float32, numpy=0.9951895>,\n",
       " 'coefs_scale': <tf.Tensor: shape=(), dtype=float32, numpy=0.15956731>,\n",
       " 'coefs_noise': <tf.Tensor: shape=(), dtype=float32, numpy=0.5349592>,\n",
       " 'predictions': <tf.Tensor: shape=(3,), dtype=float32, numpy=array([-0.2006491 ,  0.08688551, -0.83539486], dtype=float32)>}"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "state[\"samples\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Cool, that covers the central idea behind the PyMC4 core engine. There is some extra work to do to make `interact_nested` really powerful:\n",
    "\n",
    "* resolve transforms\n",
    "* resolve reparametrizations\n",
    "* variational inference\n",
    "* better error messages\n",
    "* lazy returns in posterior predictive mode\n",
    "\n",
    "Some of this functionality may be found in the corresponding [PR#125](https://github.com/pymc-devs/pymc4/pull/125)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
